# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn

from mmdet3d.models.layers import DGCNNFAModule, DGCNNGFModule
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptMultiConfig


@MODELS.register_module()
class DGCNNBackbone(BaseModule):
    """Backbone network for DGCNN.

    Args:
        in_channels (int): Input channels of point cloud.
        num_samples (tuple[int], optional): The number of samples for knn or
            ball query in each graph feature (GF) module.
            Defaults to (20, 20, 20).
        knn_modes (tuple[str], optional): Mode of KNN of each knn module.
            Defaults to ('D-KNN', 'F-KNN', 'F-KNN').
        radius (tuple[float], optional): Sampling radii of each GF module.
            Defaults to (None, None, None).
        gf_channels (tuple[tuple[int]], optional): Out channels of each mlp in
            GF module. Defaults to ((64, 64), (64, 64), (64, )).
        fa_channels (tuple[int], optional): Out channels of each mlp in FA
            module. Defaults to (1024, ).
        act_cfg (dict, optional): Config of activation layer.
            Defaults to dict(type='ReLU').
        init_cfg (dict, optional): Initialization config.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 num_samples: Sequence[int] = (20, 20, 20),
                 knn_modes: Sequence[str] = ('D-KNN', 'F-KNN', 'F-KNN'),
                 radius: Sequence[Union[float, None]] = (None, None, None),
                 gf_channels: Sequence[Sequence[int]] = ((64, 64), (64, 64),
                                                         (64, )),
                 fa_channels: Sequence[int] = (1024, ),
                 act_cfg: ConfigType = dict(type='ReLU'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        self.num_gf = len(gf_channels)

        assert len(num_samples) == len(knn_modes) == len(radius) == len(
            gf_channels), 'Num_samples, knn_modes, radius and gf_channels \
            should have the same length.'

        self.GF_modules = nn.ModuleList()
        gf_in_channel = in_channels * 2
        skip_channel_list = [gf_in_channel]  # input channel list

        for gf_index in range(self.num_gf):
            cur_gf_mlps = list(gf_channels[gf_index])
            cur_gf_mlps = [gf_in_channel] + cur_gf_mlps
            gf_out_channel = cur_gf_mlps[-1]

            self.GF_modules.append(
                DGCNNGFModule(
                    mlp_channels=cur_gf_mlps,
                    num_sample=num_samples[gf_index],
                    knn_mode=knn_modes[gf_index],
                    radius=radius[gf_index],
                    act_cfg=act_cfg))
            skip_channel_list.append(gf_out_channel)
            gf_in_channel = gf_out_channel * 2

        fa_in_channel = sum(skip_channel_list[1:])
        cur_fa_mlps = list(fa_channels)
        cur_fa_mlps = [fa_in_channel] + cur_fa_mlps

        self.FA_module = DGCNNFAModule(
            mlp_channels=cur_fa_mlps, act_cfg=act_cfg)

    def forward(self, points: Tensor) -> dict:
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, in_channels).

        Returns:
            dict[str, list[torch.Tensor]]: Outputs after graph feature (GF) and
                feature aggregation (FA) modules.

                - gf_points (list[torch.Tensor]): Outputs after each GF module.
                - fa_points (torch.Tensor): Outputs after FA module.
        """
        gf_points = [points]

        for i in range(self.num_gf):
            cur_points = self.GF_modules[i](gf_points[i])
            gf_points.append(cur_points)

        fa_points = self.FA_module(gf_points)

        out = dict(gf_points=gf_points, fa_points=fa_points)
        return out
